package com.jstarcraft.ai.jsat.parameters;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.ClassificationModelEvaluation;
import com.jstarcraft.ai.jsat.classifiers.Classifier;
import com.jstarcraft.ai.jsat.classifiers.WarmClassifier;
import com.jstarcraft.ai.jsat.distributions.Distribution;
import com.jstarcraft.ai.jsat.exceptions.FailedToFitException;
import com.jstarcraft.ai.jsat.regression.RegressionDataSet;
import com.jstarcraft.ai.jsat.regression.RegressionModelEvaluation;
import com.jstarcraft.ai.jsat.regression.Regressor;
import com.jstarcraft.ai.jsat.regression.WarmRegressor;
import com.jstarcraft.ai.jsat.utils.concurrent.ParallelUtils;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;

/**
 * GridSearch is a simple method for tuning the parameters of a classification
 * or regression algorithm. It naively tries all possible combinations of the
 * parameter values given. For this reason, it works best when only a small
 * number of parameters need to be tuned. <br>
 * The model it takes must implement the {@link Parameterized} interface. By
 * default, no parameters are selected for optimization. This is because
 * parameter value ranges are often algorithm specific. As such, the user must
 * specify the parameters and the values to test using the {@code addParameter}
 * methods.
 * 
 * @author Edward Raff
 * 
 * @see #addParameter(DoubleParameter, double...)
 * @see #addParameter(IntParameter, int...)
 */
public class GridSearch extends ModelSearch {
    private static final long serialVersionUID = -1987196172499143753L;

    /**
     * The matching list of values we will test. This includes the integer
     * parameters, which will have to be cast back and forth from doubles. When a
     * combination is applied to the base model, entry {@code i} of this list
     * supplies the value for {@code searchParams.get(i)}.
     */
    private List<DoubleArrayList> searchValues;

    /**
     * Use warm starts when possible. Enabled ({@code true}) unless changed through
     * {@link #setUseWarmStarts(boolean)}.
     */
    private boolean useWarmStarts = true;

    /**
     * Builds a grid search that tunes a regression model. No parameters are
     * searched until they are registered, for example through
     * {@link #addParameter(DoubleParameter, double...)}.
     * 
     * @param baseRegressor the regressor whose parameters will be tuned
     * @param folds         how many cross-validation folds are used to score each
     *                      combination of parameter values
     * @throws FailedToFitException if the base regressor does not implement
     *                              {@link Parameterized}
     */
    public GridSearch(Regressor baseRegressor, int folds) {
        super(baseRegressor, folds);
        // start with an empty grid; value lists are appended by addParameter
        this.searchValues = new ArrayList<DoubleArrayList>();
    }

    /**
     * Builds a grid search that tunes a classification model. No parameters are
     * searched until they are registered, for example through
     * {@link #addParameter(DoubleParameter, double...)}.
     * 
     * @param baseClassifier the classifier whose parameters will be tuned
     * @param folds          how many cross-validation folds are used to score each
     *                       combination of parameter values
     * @throws FailedToFitException if the base classifier does not implement
     *                              {@link Parameterized}
     */
    public GridSearch(Classifier baseClassifier, int folds) {
        super(baseClassifier, folds);
        // start with an empty grid; value lists are appended by addParameter
        this.searchValues = new ArrayList<DoubleArrayList>();
    }

    /**
     * Copy constructor
     * 
     * @param toCopy the object to copy
     */
    public GridSearch(GridSearch toCopy) {
        super(toCopy);
        this.useWarmStarts = toCopy.useWarmStarts;

        if (toCopy.searchValues != null) {
            this.searchValues = new ArrayList<>();
            for (DoubleArrayList ld : toCopy.searchValues) {
                DoubleArrayList newVals = new DoubleArrayList(ld);
                this.searchValues.add(newVals);
            }
        }
    }

    /**
     * Populates the search space automatically from the parameters of the base
     * model that return a non-null guess distribution for the given data. Ten
     * values are tried for each such parameter.<br>
     * <br>
     * Note, using this method with Cross Validation has the potential for
     * over-estimating the accuracy of results if the same data set is also used
     * for parameter guessing.<br>
     * <br>
     * A return value of 0 means no default parameters could be found. The intended
     * interpretation is that there are no parameters that you <i>need</i> to tune
     * to get good performance from the given model, though there will be cases
     * where the author has simply missed a class.
     *
     * @param data the data set to get parameter estimates from
     * @return the number of parameters added
     */
    public int autoAddParameters(DataSet data) {
        final int valuesPerParameter = 10;
        return autoAddParameters(data, valuesPerParameter);
    }

    /**
     * Populates the search space automatically from the parameters of the base
     * model that return a non-null guess distribution for the given data. For
     * each such parameter, {@code paramsEach} values are taken from the guessed
     * distribution at the evenly spaced quantiles
     * {@code (i + 1) / (paramsEach + 1)}; values for integer parameters are
     * rounded to the nearest integer.<br>
     * <br>
     * Note, using this method with Cross Validation has the potential for
     * over-estimating the accuracy of results if the same data set is also used
     * for parameter guessing.
     *
     * @param data       the data set to get parameter estimates from
     * @param paramsEach the number of values to try for each parameter found,
     *                   must be at least 1
     * @return the number of parameters added, 0 if none of the model's parameters
     *         supplied a guess
     * @throws IllegalArgumentException if {@code paramsEach} is less than 1
     */
    public int autoAddParameters(DataSet data, int paramsEach) {
        // an empty value list would only fail later, when training indexes into it
        if (paramsEach < 1)
            throw new IllegalArgumentException("paramsEach must be positive, not " + paramsEach);

        final Parameterized obj;
        if (baseClassifier != null)
            obj = (Parameterized) baseClassifier;
        else
            obj = (Parameterized) baseRegressor;

        final double[] quantiles = new double[paramsEach];
        for (int i = 0; i < quantiles.length; i++)
            quantiles[i] = (i + 1.0) / (paramsEach + 1.0);

        // single pass: each guess is requested once and counted as it is added
        int totalParms = 0;
        for (Parameter param : obj.getParameters()) {
            if (param instanceof DoubleParameter) {
                final Distribution dist = ((DoubleParameter) param).getGuess(data);
                if (dist == null)
                    continue;
                final double[] vals = new double[paramsEach];
                for (int i = 0; i < vals.length; i++)
                    vals[i] = dist.invCdf(quantiles[i]);

                addParameter((DoubleParameter) param, vals);
                totalParms++;
            } else if (param instanceof IntParameter) {
                final Distribution dist = ((IntParameter) param).getGuess(data);
                if (dist == null)
                    continue;
                final int[] vals = new int[paramsEach];
                for (int i = 0; i < vals.length; i++)
                    vals[i] = (int) Math.round(dist.invCdf(quantiles[i]));

                addParameter((IntParameter) param, vals);
                totalParms++;
            }
        }

        return totalParms;
    }

    /**
     * Controls whether warm starts are used. The setting only has an effect when
     * the model being tuned supports warm starts. The default is {@code true}.
     * 
     * @param useWarmStarts {@code true} to use warm starts whenever possible,
     *                      {@code false} to never use them
     */
    public void setUseWarmStarts(boolean useWarmStarts) {
        this.useWarmStarts = useWarmStarts;
    }

    /**
     * Reports whether warm starts are enabled.
     *
     * @return {@code true} if warm starts will be used when possible,
     *         {@code false} if they will not
     */
    public boolean isUseWarmStarts() {
        return this.useWarmStarts;
    }

    /**
     * Adds a new double parameter to be altered for the model being tuned. The
     * given values are stored in sorted order; for a warm parameter that does not
     * prefer low-to-high the order is reversed. A warm parameter is placed at the
     * front of the search so that its values are the fastest changing ones.
     * 
     * @param param               the model parameter
     * @param initialSearchValues the values to try for the specified parameter,
     *                            at least one is required
     * @throws IllegalArgumentException if {@code param} is {@code null} or no
     *                                  search values are given
     */
    public void addParameter(DoubleParameter param, double... initialSearchValues) {
        // validate everything before touching any state so a failure leaves the
        // parameter and value lists untouched and aligned
        if (param == null)
            throw new IllegalArgumentException("null not allowed for parameter");
        if (initialSearchValues == null || initialSearchValues.length == 0)
            throw new IllegalArgumentException("at least one search value is required for parameter");

        DoubleArrayList dl = new DoubleArrayList(initialSearchValues.length);
        for (double d : initialSearchValues)
            dl.add(d);
        // convenience, only really needed if param is warm; restrict the sort to
        // the logical size in case the backing array is larger
        Arrays.sort(dl.elements(), 0, dl.size());

        final boolean warm = param.isWarmParameter();
        if (warm && !param.preferredLowToHigh())
            Collections.reverse(dl);// put it in the preferred order

        // The parameter and its values must land at the SAME index, because
        // setParameters pairs searchParams.get(i) with searchValues.get(i).
        if (warm) {// put it at the front!
            searchParams.add(0, param);
            searchValues.add(0, dl);
        } else {
            searchParams.add(param);
            searchValues.add(dl);
        }
    }

    /**
     * Adds a new double parameter, looked up by name, to be altered for the model
     * being tuned.
     *
     * @param name                the name of the parameter
     * @param initialSearchValues the values to try for the specified parameter
     * @throws IllegalArgumentException if the named parameter is not a
     *                                  {@link DoubleParameter}
     */
    public void addParameter(String name, double... initialSearchValues) {
        final Parameter found = getParameterByName(name);
        if (found instanceof DoubleParameter) {
            addParameter((DoubleParameter) found, initialSearchValues);
            return;
        }
        throw new IllegalArgumentException("Parameter " + name + " is not for double values");
    }

    /**
     * Adds a new int parameter to be altered for the model being tuned. The given
     * values are stored (as doubles) in sorted order; for a warm parameter that
     * does not prefer low-to-high the order is reversed. A warm parameter is
     * placed at the front of the search so that its values are the fastest
     * changing ones.
     * 
     * @param param               the model parameter
     * @param initialSearchValues the values to try for the specified parameter,
     *                            at least one is required
     * @throws IllegalArgumentException if {@code param} is {@code null} or no
     *                                  search values are given
     */
    public void addParameter(IntParameter param, int... initialSearchValues) {
        // validate everything before touching any state so a failure leaves the
        // parameter and value lists untouched and aligned
        if (param == null)
            throw new IllegalArgumentException("null not allowed for parameter");
        if (initialSearchValues == null || initialSearchValues.length == 0)
            throw new IllegalArgumentException("at least one search value is required for parameter");

        DoubleArrayList dl = new DoubleArrayList(initialSearchValues.length);
        for (double d : initialSearchValues)
            dl.add(d);
        // convenience, only really needed if param is warm; restrict the sort to
        // the logical size in case the backing array is larger
        Arrays.sort(dl.elements(), 0, dl.size());

        final boolean warm = param.isWarmParameter();
        if (warm && !param.preferredLowToHigh())
            Collections.reverse(dl);// put it in the preferred order

        // The parameter and its values must land at the SAME index, because
        // setParameters pairs searchParams.get(i) with searchValues.get(i).
        if (warm) {// put it at the front!
            searchParams.add(0, param);
            searchValues.add(0, dl);
        } else {
            searchParams.add(param);
            searchValues.add(dl);
        }
    }

    /**
     * Adds a new integer parameter, looked up by name, to be altered for the
     * model being tuned.
     *
     * @param name                the name of the parameter
     * @param initialSearchValues the values to try for the specified parameter
     * @throws IllegalArgumentException if the named parameter is not an
     *                                  {@link IntParameter}
     */
    public void addParameter(String name, int... initialSearchValues) {
        final Parameter found = getParameterByName(name);
        if (found instanceof IntParameter) {
            addParameter((IntParameter) found, initialSearchValues);
            return;
        }
        throw new IllegalArgumentException("Parameter " + name + " is not for int values");
    }

    /**
     * Runs the grid search on a regression data set. Every combination of the
     * registered search values is applied to the base regressor and a clone is
     * taken; each clone is scored by cross-validation with the regression target
     * score, and the clone with the best mean score is kept as the trained
     * regressor. When {@code trainFinalModel} is set, that regressor is trained
     * on the complete data set before being stored.
     *
     * @param dataSet  the data to search over and train on
     * @param parallel {@code true} to allow multi-threaded execution
     */
    @Override
    public void train(final RegressionDataSet dataSet, final boolean parallel) {
        // Ordered so that the head of the queue is the evaluation with the best
        // mean score: ascending when lower is better, descending otherwise.
        final PriorityQueue<RegressionModelEvaluation> ranked = new PriorityQueue<RegressionModelEvaluation>(folds, (lhs, rhs) -> {
            final double lhsMean = lhs.getScoreStats(regressionTargetScore).getMean();
            final double rhsMean = rhs.getScoreStats(regressionTargetScore).getMean();
            final int ascending = Double.compare(lhsMean, rhsMean);
            return regressionTargetScore.lowerIsBetter() ? ascending : -ascending;
        });

        /*
         * Odometer over the grid: position i belongs to the i'th search parameter and
         * holds the index of the value currently selected for it. Incrementing with
         * carry walks through every combination.
         */
        final int[] valueIndex = new int[searchParams.size()];

        // One clone of the base model per combination of parameter values.
        final List<Regressor> candidates = new ArrayList<>();
        do {
            setParameters(valueIndex);
            candidates.add(baseRegressor.clone());
        } while (!incrementCombination(valueIndex));

        /*
         * When the same CV folds are shared by all candidates, split once up front
         * and pre-combine the training portions so that any caching can be re-used.
         */
        final List<RegressionDataSet> foldSets;
        final List<RegressionDataSet> foldTrainSets;
        if (!reuseSameCVFolds) {
            foldSets = null;
            foldTrainSets = null;
        } else {
            foldSets = dataSet.cvSet(folds);
            foldTrainSets = new ArrayList<>(foldSets.size());
            for (int heldOut = 0; heldOut < foldSets.size(); heldOut++)
                foldTrainSets.add(RegressionDataSet.comineAllBut(foldSets, heldOut));
        }

        // Parallelism is spent either across candidate models or inside each
        // evaluation, never both.
        final boolean modelsInParallel = parallel && trainModelsInParallel;
        final boolean foldsInParallel = !trainModelsInParallel && parallel;

        /*
         * Warm starts are only usable if the model supports them at all and, when
         * the model can only warm start from the same data, the CV splits are being
         * re-used. With a = warmFromSameDataOnly and b = reuseSameCVFolds the
         * condition (a && b) || !a simplifies to !a || b.
         */
        final boolean warmEligible;
        if (useWarmStarts && baseRegressor instanceof WarmRegressor)
            warmEligible = !((WarmRegressor) baseRegressor).warmFromSameDataOnly() || reuseSameCVFolds;
        else
            warmEligible = false;

        if (warmEligible) {
            /*
             * Each worker receives a contiguous chunk of the candidate list. The warm
             * parameter was placed first, so neighbouring candidates mostly differ in
             * the warm parameter; the models kept from one evaluation seed the next
             * one in the same chunk.
             */
            ParallelUtils.run(modelsInParallel, candidates.size(), (start, end) -> {
                Regressor[] warmSeeds = null;
                for (int pos = start; pos < end; pos++) {
                    final RegressionModelEvaluation eval = new RegressionModelEvaluation(candidates.get(pos), dataSet, foldsInParallel);
                    eval.setKeepModels(true);// we need these to do warm starts!
                    eval.setWarmModels(warmSeeds);
                    eval.addScorer(regressionTargetScore.clone());
                    if (reuseSameCVFolds)
                        eval.evaluateCrossValidation(foldSets, foldTrainSets);
                    else
                        eval.evaluateCrossValidation(folds);
                    warmSeeds = eval.getKeptModels();
                    synchronized (ranked) {
                        ranked.add(eval);
                    }
                }
            });
        } else {
            // regular CV, train a new model from scratch at every step
            ParallelUtils.run(modelsInParallel, candidates.size(), (indx) -> {
                final RegressionModelEvaluation eval = new RegressionModelEvaluation(candidates.get(indx), dataSet, foldsInParallel);
                eval.addScorer(regressionTargetScore.clone());
                if (reuseSameCVFolds)
                    eval.evaluateCrossValidation(foldSets, foldTrainSets);
                else
                    eval.evaluateCrossValidation(folds);
                synchronized (ranked) {
                    ranked.add(eval);
                }
            });
        }

        // The head of the queue is the best regressor; optionally train it on the
        // whole data set.
        final Regressor winner = ranked.peek().getRegressor();
        if (trainFinalModel) {
            // a warm final fit is only possible if the model may warm start from
            // data other than what it was originally trained on
            final boolean warmFinal = useWarmStarts && winner instanceof WarmRegressor && !((WarmRegressor) winner).warmFromSameDataOnly();
            if (!warmFinal) {
                winner.train(dataSet, parallel);
            } else {
                final WarmRegressor warmWinner = (WarmRegressor) winner;
                warmWinner.train(dataSet, warmWinner.clone(), parallel);
            }
        }
        trainedRegressor = winner;

    }

    /**
     * Runs the grid search on a classification data set. Every combination of
     * the registered search values is applied to the base classifier and a clone
     * is taken; each clone is scored by cross-validation with the classification
     * target score, and the clone with the best mean score is kept as the trained
     * classifier. When {@code trainFinalModel} is set, that classifier is trained
     * on the complete data set before being stored.
     *
     * @param dataSet  the data to search over and train on
     * @param parallel {@code true} to allow multi-threaded execution
     */
    @Override
    public void train(final ClassificationDataSet dataSet, final boolean parallel) {
        // Ordered so that the head of the queue is the evaluation with the best
        // mean score: ascending when lower is better, descending otherwise.
        final PriorityQueue<ClassificationModelEvaluation> ranked = new PriorityQueue<ClassificationModelEvaluation>(folds, (lhs, rhs) -> {
            final double lhsMean = lhs.getScoreStats(classificationTargetScore).getMean();
            final double rhsMean = rhs.getScoreStats(classificationTargetScore).getMean();
            final int ascending = Double.compare(lhsMean, rhsMean);
            return classificationTargetScore.lowerIsBetter() ? ascending : -ascending;
        });

        /*
         * Odometer over the grid: position i belongs to the i'th search parameter and
         * holds the index of the value currently selected for it. Incrementing with
         * carry walks through every combination.
         */
        final int[] valueIndex = new int[searchParams.size()];

        // One clone of the base model per combination of parameter values.
        final List<Classifier> candidates = new ArrayList<>();
        do {
            setParameters(valueIndex);
            candidates.add(baseClassifier.clone());
        } while (!incrementCombination(valueIndex));

        /*
         * When the same CV folds are shared by all candidates, split once up front
         * and pre-combine the training portions so that any caching can be re-used.
         */
        final List<ClassificationDataSet> foldSets;
        final List<ClassificationDataSet> foldTrainSets;
        if (!reuseSameCVFolds) {
            foldSets = null;
            foldTrainSets = null;
        } else {
            foldSets = dataSet.cvSet(folds);
            foldTrainSets = new ArrayList<>(foldSets.size());
            for (int heldOut = 0; heldOut < foldSets.size(); heldOut++)
                foldTrainSets.add(ClassificationDataSet.comineAllBut(foldSets, heldOut));
        }

        // Parallelism is spent either across candidate models or inside each
        // evaluation, never both.
        final boolean modelsInParallel = parallel && trainModelsInParallel;
        final boolean foldsInParallel = !trainModelsInParallel && parallel;

        /*
         * Warm starts are only usable if the model supports them at all and, when
         * the model can only warm start from the same data, the CV splits are being
         * re-used. With a = warmFromSameDataOnly and b = reuseSameCVFolds the
         * condition (a && b) || !a simplifies to !a || b.
         */
        final boolean warmEligible;
        if (useWarmStarts && baseClassifier instanceof WarmClassifier)
            warmEligible = !((WarmClassifier) baseClassifier).warmFromSameDataOnly() || reuseSameCVFolds;
        else
            warmEligible = false;

        if (warmEligible) {
            /*
             * Each worker receives a contiguous chunk of the candidate list. The warm
             * parameter was placed first, so neighbouring candidates mostly differ in
             * the warm parameter; the models kept from one evaluation seed the next
             * one in the same chunk.
             */
            ParallelUtils.run(modelsInParallel, candidates.size(), (start, end) -> {
                Classifier[] warmSeeds = null;
                for (int pos = start; pos < end; pos++) {
                    final ClassificationModelEvaluation eval = new ClassificationModelEvaluation(candidates.get(pos), dataSet, foldsInParallel);
                    eval.setKeepModels(true);// we need these to do warm starts!
                    eval.setWarmModels(warmSeeds);
                    eval.addScorer(classificationTargetScore.clone());
                    if (reuseSameCVFolds)
                        eval.evaluateCrossValidation(foldSets, foldTrainSets);
                    else
                        eval.evaluateCrossValidation(folds);
                    warmSeeds = eval.getKeptModels();
                    synchronized (ranked) {
                        ranked.add(eval);
                    }
                }
            });
        } else {
            // regular CV, train a new model from scratch at every step
            ParallelUtils.run(modelsInParallel, candidates.size(), (int indx) -> {
                final ClassificationModelEvaluation eval = new ClassificationModelEvaluation(candidates.get(indx), dataSet, foldsInParallel);
                eval.addScorer(classificationTargetScore.clone());
                if (reuseSameCVFolds)
                    eval.evaluateCrossValidation(foldSets, foldTrainSets);
                else
                    eval.evaluateCrossValidation(folds);
                synchronized (ranked) {
                    ranked.add(eval);
                }
            });
        }

        // The head of the queue is the best classifier; optionally train it on the
        // whole data set.
        final Classifier winner = ranked.peek().getClassifier();
        if (trainFinalModel) {
            // a warm final fit is only possible if the model may warm start from
            // data other than what it was originally trained on
            final boolean warmFinal = useWarmStarts && winner instanceof WarmClassifier && !((WarmClassifier) winner).warmFromSameDataOnly();
            if (!warmFinal) {
                winner.train(dataSet, parallel);
            } else {
                final WarmClassifier warmWinner = (WarmClassifier) winner;
                warmWinner.train(dataSet, warmWinner.clone(), parallel);
            }
        }
        trainedClassifier = winner;
    }

    /**
     * Creates a copy of this grid search through the copy constructor.
     *
     * @return a new {@code GridSearch} duplicating this one
     */
    @Override
    public GridSearch clone() {
        final GridSearch duplicate = new GridSearch(this);
        return duplicate;
    }

    /**
     * Advances the array used to keep track of which combination of parameter
     * values is selected. The array is treated like an odometer: position 0 is
     * incremented, and whenever a position runs past the number of values
     * available for its parameter it is reset to 0 and the next position is
     * incremented.
     * 
     * @param setTo the array of length equal to {@code searchParams} holding, for
     *              each parameter, the index of the value currently selected; it
     *              is modified in place
     * @return {@code true} if all combinations have been tried, or {@code false}
     *         if combinations remain to be attempted. An empty array (no search
     *         parameters registered) always yields {@code true}, since the
     *         unmodified base model is the only combination.
     */
    private boolean incrementCombination(int[] setTo) {
        // Nothing to advance: without this guard setTo[0] would be out of bounds
        // and training with no registered parameters would crash.
        if (setTo.length == 0)
            return true;

        final int last = setTo.length - 1;
        setTo[0]++;

        // carry overflowing positions into their neighbour; the last position is
        // never reset so that its overflow can signal completion
        for (int pos = 0; pos < last && setTo[pos] >= searchValues.get(pos).size(); pos++) {
            setTo[pos] = 0;
            setTo[pos + 1]++;
        }

        return setTo[last] >= searchValues.get(last).size();
    }

    /**
     * Applies one combination of search values to the search parameters.
     * 
     * @param setTo position {@code i} selects, for the {@code i}'th search
     *              parameter, the index of the value to use from that
     *              parameter's list of search values
     */
    private void setParameters(int[] setTo) {
        int slot = 0;
        for (final int chosen : setTo) {
            final Parameter target = searchParams.get(slot);
            if (target instanceof DoubleParameter) {
                final double value = searchValues.get(slot).getDouble(chosen);
                ((DoubleParameter) target).setValue(value);
            } else if (target instanceof IntParameter) {
                // integer values are stored as doubles, truncate back to int
                final int value = (int) searchValues.get(slot).getDouble(chosen);
                ((IntParameter) target).setValue(value);
            }
            slot++;
        }
    }
}
